{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ACbjNjyO4f_8"
      },
      "source": [
        "##### Copyright 2019 The TensorFlow Hub Authors.\n",
        "\n",
        "Licensed under the Apache License, Version 2.0 (the \"License\");"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "MCM50vaM4jiK"
      },
      "outputs": [],
      "source": [
        "# Copyright 2018 The TensorFlow Hub Authors. All Rights Reserved.\n",
        "#\n",
        "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "#     http://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License.\n",
        "# =============================================================================="
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9qOVy-_vmuUP"
      },
      "source": [
        "# Semantic Search with Approximate Nearest Neighbors and Text Embeddings\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MfBg1C5NB3X0"
      },
      "source": [
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://www.tensorflow.org/hub/tutorials/tf2_semantic_approximate_nearest_neighbors\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" />View on TensorFlow.org</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/hub/tutorials/tf2_semantic_approximate_nearest_neighbors.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://github.com/tensorflow/docs/blob/master/site/en/hub/tutorials/tf2_semantic_approximate_nearest_neighbors.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View on GitHub</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a href=\"https://storage.googleapis.com/tensorflow_docs/docs/site/en/hub/tutorials/tf2_semantic_approximate_nearest_neighbors.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\" />Download notebook</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a href=\"https://tfhub.dev/google/nnlm-en-dim128/2\"><img src=\"https://www.tensorflow.org/images/hub_logo_32px.png\" />See TF Hub model</a>\n",
        "  </td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3T4d77AJaKte"
      },
      "source": [
        "This tutorial illustrates how to generate embeddings from a [TensorFlow Hub](https://tfhub.dev) (TF-Hub) model given input data, and build an approximate nearest neighbors (ANN) index using the extracted embeddings. The index can then be used for real-time similarity matching and retrieval.\n",
        "\n",
        "When dealing with a large corpus of data, it's not efficient to perform exact matching by scanning the whole repository to find the most similar items to a given query in real time. Thus, we use an approximate similarity matching algorithm, which allows us to trade off a little bit of accuracy in finding exact nearest neighbor matches for a significant boost in speed.\n",
        "\n",
        "In this tutorial, we show an example of real-time text search over a corpus of news headlines to find the headlines that are most similar to a query. Unlike keyword search, this captures the semantic similarity encoded in the text embedding.\n",
        "\n",
        "The steps of this tutorial are:\n",
        "1. Download sample data.\n",
        "2. Generate embeddings for the data using a TF-Hub model.\n",
        "3. Build an ANN index for the embeddings.\n",
        "4. Use the index for similarity matching.\n",
        "\n",
        "We use [Apache Beam](https://beam.apache.org/documentation/programming-guide/) to generate the embeddings from the TF-Hub model. We also use Spotify's [ANNOY](https://github.com/spotify/annoy) library to build the approximate nearest neighbor index."
      ]
    },
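    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the cost of exact matching concrete, here is a minimal sketch of brute-force search with cosine similarity. It is an illustration only and not part of the tutorial pipeline: the function name `exact_top_k` and the toy vectors are assumptions made for this example. Every query is compared against every row of the corpus, so the work grows linearly with the corpus size, which is the cost an ANN index is meant to avoid.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "def exact_top_k(corpus, query, k=2):\n",
        "  '''Returns the indices of the k rows of corpus most similar to query'''\n",
        "  # Normalize so that a dot product equals the cosine similarity.\n",
        "  corpus_unit = corpus / np.linalg.norm(corpus, axis=1, keepdims=True)\n",
        "  query_unit = query / np.linalg.norm(query)\n",
        "  scores = corpus_unit.dot(query_unit)  # One comparison per corpus row.\n",
        "  return np.argsort(-scores)[:k]\n",
        "\n",
        "# Hypothetical 2-D stand-ins for embeddings; real ones come from the model.\n",
        "toy_corpus = np.array([[1.0, 0.0], [0.0, 1.0], [0.9, 0.1], [-1.0, 0.0]])\n",
        "toy_query = np.array([2.0, 0.0])\n",
        "top = exact_top_k(toy_corpus, toy_query)\n",
        "print(top)\n",
        "```\n",
        "\n",
        "Here `top` holds the indices of the two corpus rows that point in the direction closest to the query. An ANN index aims to return a similar answer while inspecting only a small part of the corpus."
      ]
    },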
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nM17v_mEVSnd"
      },
      "source": [
        "### More models\n",
        "For models that have the same architecture but were trained on a different language, refer to [this](https://tfhub.dev/google/collections/nnlm/1) collection. [Here](https://tfhub.dev/s?module-type=text-embedding) you can find all text embeddings that are currently hosted on [tfhub.dev](https://tfhub.dev/). "
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Q0jr0QK9qO5P"
      },
      "source": [
        "## Setup"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "whMRj9qeqed4"
      },
      "source": [
        "Install the required libraries."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "qmXkLPoaqS--"
      },
      "outputs": [],
      "source": [
        "!pip install apache_beam\n",
        "!pip install 'scikit_learn~=0.23.0'  # For gaussian_random_matrix.\n",
        "!pip install annoy"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "A-vBZiCCqld0"
      },
      "source": [
        "Import the required libraries."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6NTYbdWcseuK"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "import sys\n",
        "import pickle\n",
        "from collections import namedtuple\n",
        "from datetime import datetime\n",
        "import numpy as np\n",
        "import apache_beam as beam\n",
        "from apache_beam.transforms import util\n",
        "import tensorflow as tf\n",
        "import tensorflow_hub as hub\n",
        "import annoy\n",
        "from sklearn.random_projection import gaussian_random_matrix"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "tx0SZa6-7b-f"
      },
      "outputs": [],
      "source": [
        "print('TF version: {}'.format(tf.__version__))\n",
        "print('TF-Hub version: {}'.format(hub.__version__))\n",
        "print('Apache Beam version: {}'.format(beam.__version__))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "P6Imq876rLWx"
      },
      "source": [
        "## 1. Download Sample Data\n",
        "\n",
        "The [A Million News Headlines](https://dataverse.harvard.edu/dataset.xhtml?persistentId=doi:10.7910/DVN/SYBGZL#) dataset contains news headlines published over a period of 15 years, sourced from the reputable Australian Broadcasting Corp. (ABC). It provides a summarized historical record of noteworthy events around the globe from early 2003 to the end of 2017, with a more granular focus on Australia.\n",
        "\n",
        "**Format**: Tab-separated two-column data: 1) publication date and 2) headline text. We are only interested in the headline text.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "OpF57n8e5C9D"
      },
      "outputs": [],
      "source": [
        "!wget 'https://dataverse.harvard.edu/api/access/datafile/3450625?format=tab&gbrecs=true' -O raw.tsv\n",
        "!wc -l raw.tsv\n",
        "!head raw.tsv"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Reeoc9z0zTxJ"
      },
      "source": [
        "For simplicity, we only keep the headline text and remove the publication date."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "INPWa4upv_yJ"
      },
      "outputs": [],
      "source": [
        "!rm -rf corpus\n",
        "!mkdir corpus\n",
        "\n",
        "with open('corpus/text.txt', 'w') as out_file:\n",
        "  with open('raw.tsv', 'r') as in_file:\n",
        "    for line in in_file:\n",
        "      headline = line.split('\\t')[1].strip().strip('\"')\n",
        "      out_file.write(headline+\"\\n\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "5-oedX40z6o2"
      },
      "outputs": [],
      "source": [
        "!tail corpus/text.txt"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2AngMtH50jNb"
      },
      "source": [
        "## 2. Generate Embeddings for the Data\n",
        "\n",
        "In this tutorial, we use the [Neural Network Language Model (NNLM)](https://tfhub.dev/google/nnlm-en-dim128/2) to generate embeddings for the headline data. The sentence embeddings can then be used to compute sentence-level semantic similarity. We run the embedding generation process using Apache Beam."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "F_DvXnDB1pEX"
      },
      "source": [
        "### Embedding extraction method"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "yL7OEY1E0A35"
      },
      "outputs": [],
      "source": [
        "embed_fn = None\n",
        "\n",
        "def generate_embeddings(text, model_url, random_projection_matrix=None):\n",
        "  # Beam will run this function in different processes that need to\n",
        "  # import hub and load embed_fn (if not previously loaded)\n",
        "  global embed_fn\n",
        "  if embed_fn is None:\n",
        "    embed_fn = hub.load(model_url)\n",
        "  embedding = embed_fn(text).numpy()\n",
        "  if random_projection_matrix is not None:\n",
        "    embedding = embedding.dot(random_projection_matrix)\n",
        "  return text, embedding\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "g6pXBVxsVUbm"
      },
      "source": [
        "### Convert to tf.Example method"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "JMjqjWZNVVzd"
      },
      "outputs": [],
      "source": [
        "def to_tf_example(entries):\n",
        "  examples = []\n",
        "\n",
        "  text_list, embedding_list = entries\n",
        "  for i in range(len(text_list)):\n",
        "    text = text_list[i]\n",
        "    embedding = embedding_list[i]\n",
        "\n",
        "    features = {\n",
        "        'text': tf.train.Feature(\n",
        "            bytes_list=tf.train.BytesList(value=[text.encode('utf-8')])),\n",
        "        'embedding': tf.train.Feature(\n",
        "            float_list=tf.train.FloatList(value=embedding.tolist()))\n",
        "    }\n",
        "  \n",
        "    example = tf.train.Example(\n",
        "        features=tf.train.Features(\n",
        "            feature=features)).SerializeToString(deterministic=True)\n",
        "  \n",
        "    examples.append(example)\n",
        "  \n",
        "  return examples"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gDiV4uQCVYGH"
      },
      "source": [
        "### Beam pipeline"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "jCGUIB172m2G"
      },
      "outputs": [],
      "source": [
        "def run_hub2emb(args):\n",
        "  '''Runs the embedding generation pipeline'''\n",
        "\n",
        "  options = beam.options.pipeline_options.PipelineOptions(**args)\n",
        "  args = namedtuple(\"options\", args.keys())(*args.values())\n",
        "\n",
        "  with beam.Pipeline(args.runner, options=options) as pipeline:\n",
        "    (\n",
        "        pipeline\n",
        "        | 'Read sentences from files' >> beam.io.ReadFromText(\n",
        "            file_pattern=args.data_dir)\n",
        "        | 'Batch elements' >> util.BatchElements(\n",
        "            min_batch_size=args.batch_size, max_batch_size=args.batch_size)\n",
        "        | 'Generate embeddings' >> beam.Map(\n",
        "            generate_embeddings, args.model_url, args.random_projection_matrix)\n",
        "        | 'Encode to tf example' >> beam.FlatMap(to_tf_example)\n",
        "        | 'Write to TFRecords files' >> beam.io.WriteToTFRecord(\n",
        "            file_path_prefix='{}/emb'.format(args.output_dir),\n",
        "            file_name_suffix='.tfrecords')\n",
        "    )"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nlbQdiYNVvne"
      },
      "source": [
        "### Generating Random Projection Weight Matrix\n",
        "\n",
        "[Random projection](https://en.wikipedia.org/wiki/Random_projection) is a simple, yet powerful technique used to reduce the dimensionality of a set of points which lie in Euclidean space. For a theoretical background, see the [Johnson-Lindenstrauss lemma](https://en.wikipedia.org/wiki/Johnson%E2%80%93Lindenstrauss_lemma).\n",
        "\n",
        "Reducing the dimensionality of the embeddings with random projection means that less time is needed to build and query the ANN index.\n",
        "\n",
        "In this tutorial we use [Gaussian Random Projection](https://en.wikipedia.org/wiki/Random_projection#Gaussian_random_projection) from the [Scikit-learn](https://scikit-learn.org/stable/modules/random_projection.html#gaussian-random-projection) library."
      ]
    },
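    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a rough illustration of why random projection is usable here, the NumPy-only sketch below projects random 128-dimensional points down to 64 dimensions and compares pairwise distances before and after. It does not use scikit-learn; the matrix entries are drawn from N(0, 1 / projected_dim), which is the scaling that the scikit-learn documentation describes for Gaussian random matrices. The `demo_` names, the seed, and the random points are assumptions made for this example.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "rng = np.random.default_rng(0)\n",
        "demo_original_dim, demo_projected_dim = 128, 64\n",
        "\n",
        "# Entries drawn from N(0, 1 / projected_dim), so that squared distances\n",
        "# are preserved in expectation.\n",
        "demo_matrix = rng.normal(\n",
        "    loc=0.0, scale=1.0 / np.sqrt(demo_projected_dim),\n",
        "    size=(demo_original_dim, demo_projected_dim))\n",
        "\n",
        "# Hypothetical stand-ins for embeddings.\n",
        "points = rng.normal(size=(200, demo_original_dim))\n",
        "projected = points.dot(demo_matrix)\n",
        "\n",
        "# Compare distances between consecutive points before and after projection.\n",
        "dist_before = np.linalg.norm(points[1:] - points[:-1], axis=1)\n",
        "dist_after = np.linalg.norm(projected[1:] - projected[:-1], axis=1)\n",
        "ratios = dist_after / dist_before\n",
        "print('Mean distance ratio: {:.3f}'.format(ratios.mean()))\n",
        "```\n",
        "\n",
        "A mean ratio close to 1 indicates that distances are roughly preserved on average. Individual pairs can deviate more, and the deviation tends to grow as the projected dimension shrinks."
      ]
    },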
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "1yw1xgtNVv52"
      },
      "outputs": [],
      "source": [
        "def generate_random_projection_weights(original_dim, projected_dim):\n",
        "  '''Creates a Gaussian random projection matrix and stores it on disk'''\n",
        "  random_projection_matrix = gaussian_random_matrix(\n",
        "      n_components=projected_dim, n_features=original_dim).T\n",
        "  print(\"A Gaussian random weight matrix was created with shape of {}\".format(random_projection_matrix.shape))\n",
        "  print('Storing random projection matrix to disk...')\n",
        "  with open('random_projection_matrix', 'wb') as handle:\n",
        "    pickle.dump(random_projection_matrix, \n",
        "                handle, protocol=pickle.HIGHEST_PROTOCOL)\n",
        "        \n",
        "  return random_projection_matrix"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "aJZUfT3NE7kj"
      },
      "source": [
        "### Set parameters\n",
        "If you want to build an index using the original embedding space without random projection, set the `projected_dim` parameter to `None`. Note that this will slow down the indexing step for high-dimensional embeddings."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "77-Cow7uE74T"
      },
      "outputs": [],
      "source": [
        "model_url = 'https://tfhub.dev/google/nnlm-en-dim128/2' #@param {type:\"string\"}\n",
        "projected_dim = 64  #@param {type:\"number\"}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "On-MbzD922kb"
      },
      "source": [
        "### Run pipeline"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Y3I1Wv4i21yY"
      },
      "outputs": [],
      "source": [
        "import tempfile\n",
        "\n",
        "output_dir = tempfile.mkdtemp()\n",
        "original_dim = hub.load(model_url)(['']).shape[1]\n",
        "random_projection_matrix = None\n",
        "\n",
        "if projected_dim:\n",
        "  random_projection_matrix = generate_random_projection_weights(\n",
        "      original_dim, projected_dim)\n",
        "\n",
        "args = {\n",
        "    'job_name': 'hub2emb-{}'.format(datetime.utcnow().strftime('%y%m%d-%H%M%S')),\n",
        "    'runner': 'DirectRunner',\n",
        "    'batch_size': 1024,\n",
        "    'data_dir': 'corpus/*.txt',\n",
        "    'output_dir': output_dir,\n",
        "    'model_url': model_url,\n",
        "    'random_projection_matrix': random_projection_matrix,\n",
        "}\n",
        "\n",
        "print(\"Pipeline args are set.\")\n",
        "args"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "iS9obmeP4ZOA"
      },
      "outputs": [],
      "source": [
        "print(\"Running pipeline...\")\n",
        "%time run_hub2emb(args)\n",
        "print(\"Pipeline is done.\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "JAwOo7gQWvVd"
      },
      "outputs": [],
      "source": [
        "!ls {output_dir}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "HVnee4e6U90u"
      },
      "source": [
        "Read some of the generated embeddings..."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-K7pGXlXOj1N"
      },
      "outputs": [],
      "source": [
        "embed_file = os.path.join(output_dir, 'emb-00000-of-00001.tfrecords')\n",
        "sample = 5\n",
        "\n",
        "# Create a description of the features.\n",
        "feature_description = {\n",
        "    'text': tf.io.FixedLenFeature([], tf.string),\n",
        "    # Fall back to the original dimension when no projection is applied.\n",
        "    'embedding': tf.io.FixedLenFeature([projected_dim or original_dim], tf.float32)\n",
        "}\n",
        "\n",
        "def _parse_example(example):\n",
        "  # Parse the input `tf.Example` proto using the dictionary above.\n",
        "  return tf.io.parse_single_example(example, feature_description)\n",
        "\n",
        "dataset = tf.data.TFRecordDataset(embed_file)\n",
        "for record in dataset.take(sample).map(_parse_example):\n",
        "  print(\"{}: {}\".format(record['text'].numpy().decode('utf-8'), record['embedding'].numpy()[:10]))\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "agGoaMSgY8wN"
      },
      "source": [
        "## 3. Build the ANN Index for the Embeddings\n",
        "\n",
        "[ANNOY](https://github.com/spotify/annoy) (Approximate Nearest Neighbors Oh Yeah) is a C++ library with Python bindings to search for points in space that are close to a given query point. It also creates large read-only file-based data structures that are mapped into memory. It is built and used by [Spotify](https://www.spotify.com) for music recommendations. If you are interested, you can experiment with alternatives to ANNOY such as [NGT](https://github.com/yahoojapan/NGT) and [FAISS](https://github.com/facebookresearch/faiss)."
      ]
    },
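    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The `build_index` function in this section defaults to `metric='angular'`. According to the ANNOY README, the angular distance between two vectors is the Euclidean distance between their normalized forms, which equals sqrt(2 * (1 - cos(u, v))). The sketch below computes that quantity with NumPy only, to show what the distances mean. The helper name `angular_distance` and the sample vectors are assumptions made for this example, and the code does not call ANNOY itself.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "def angular_distance(u, v):\n",
        "  '''Euclidean distance between the unit-normalized forms of u and v'''\n",
        "  cosine = np.dot(u, v) / (np.linalg.norm(u) * np.linalg.norm(v))\n",
        "  # Clip to guard against rounding slightly outside [-1, 1].\n",
        "  cosine = np.clip(cosine, -1.0, 1.0)\n",
        "  return np.sqrt(2.0 * (1.0 - cosine))\n",
        "\n",
        "same_direction = angular_distance(np.array([1.0, 0.0]), np.array([3.0, 0.0]))\n",
        "orthogonal = angular_distance(np.array([1.0, 0.0]), np.array([0.0, 2.0]))\n",
        "opposite = angular_distance(np.array([1.0, 0.0]), np.array([-1.0, 0.0]))\n",
        "print(same_direction, orthogonal, opposite)\n",
        "```\n",
        "\n",
        "The distance depends only on direction, not on vector length: it is 0 for vectors pointing the same way, sqrt(2) for orthogonal vectors, and 2 for opposite vectors."
      ]
    },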
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "UcPDspU3WjgH"
      },
      "outputs": [],
      "source": [
        "def build_index(embedding_files_pattern, index_filename, vector_length, \n",
        "    metric='angular', num_trees=100):\n",
        "  '''Builds an ANNOY index'''\n",
        "\n",
        "  annoy_index = annoy.AnnoyIndex(vector_length, metric=metric)\n",
        "  # Mapping between the item and its identifier in the index\n",
        "  mapping = {}\n",
        "\n",
        "  embed_files = tf.io.gfile.glob(embedding_files_pattern)\n",
        "  num_files = len(embed_files)\n",
        "  print('Found {} embedding file(s).'.format(num_files))\n",
        "\n",
        "  item_counter = 0\n",
        "  for i, embed_file in enumerate(embed_files):\n",
        "    print('Loading embeddings in file {} of {}...'.format(i+1, num_files))\n",
        "    dataset = tf.data.TFRecordDataset(embed_file)\n",
        "    for record in dataset.map(_parse_example):\n",
        "      text = record['text'].numpy().decode(\"utf-8\")\n",
        "      embedding = record['embedding'].numpy()\n",
        "      mapping[item_counter] = text\n",
        "      annoy_index.add_item(item_counter, embedding)\n",
        "      item_counter += 1\n",
        "      if item_counter % 100000 == 0:\n",
        "        print('{} items loaded to the index'.format(item_counter))\n",
        "\n",
        "  print('A total of {} items added to the index'.format(item_counter))\n",
        "\n",
        "  print('Building the index with {} trees...'.format(num_trees))\n",
        "  annoy_index.build(n_trees=num_trees)\n",
        "  print('Index is successfully built.')\n",
        "  \n",
        "  print('Saving index to disk...')\n",
        "  annoy_index.save(index_filename)\n",
        "  print('Index is saved to disk.')\n",
        "  print(\"Index file size: {} GB\".format(\n",
        "    round(os.path.getsize(index_filename) / float(1024 ** 3), 2)))\n",
        "  annoy_index.unload()\n",
        "\n",
        "  print('Saving mapping to disk...')\n",
        "  with open(index_filename + '.mapping', 'wb') as handle:\n",
        "    pickle.dump(mapping, handle, protocol=pickle.HIGHEST_PROTOCOL)\n",
        "  print('Mapping is saved to disk.')\n",
        "  print(\"Mapping file size: {} MB\".format(\n",
        "    round(os.path.getsize(index_filename + '.mapping') / float(1024 ** 2), 2)))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "AgyOQhUq6FNE"
      },
      "outputs": [],
      "source": [
        "embedding_files = \"{}/emb-*.tfrecords\".format(output_dir)\n",
        "# Fall back to the original dimension when no projection is applied.\n",
        "embedding_dimension = projected_dim or original_dim\n",
        "index_filename = \"index\"\n",
        "\n",
        "!rm -f {index_filename}\n",
        "!rm -f {index_filename}.mapping\n",
        "\n",
        "%time build_index(embedding_files, index_filename, embedding_dimension)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Ic31Tm5cgAd5"
      },
      "outputs": [],
      "source": [
        "!ls"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "maGxDl8ufP-p"
      },
      "source": [
        "## 4. Use the Index for Similarity Matching\n",
        "Now we can use the ANN index to find news headlines that are semantically close to an input query."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_dIs8W78fYPp"
      },
      "source": [
        "### Load the index and the mapping files"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "jlTTrbQHayvb"
      },
      "outputs": [],
      "source": [
        "# Use the same metric that the index was built with.\n",
        "index = annoy.AnnoyIndex(embedding_dimension, metric='angular')\n",
        "index.load(index_filename, prefault=True)\n",
        "print('Annoy index is loaded.')\n",
        "with open(index_filename + '.mapping', 'rb') as handle:\n",
        "  mapping = pickle.load(handle)\n",
        "print('Mapping file is loaded.')\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "y6liFMSUh08J"
      },
      "source": [
        "### Similarity matching method"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "mUxjTag8hc16"
      },
      "outputs": [],
      "source": [
        "def find_similar_items(embedding, num_matches=5):\n",
        "  '''Finds similar items to a given embedding in the ANN index'''\n",
        "  ids = index.get_nns_by_vector(\n",
        "      embedding, num_matches, search_k=-1, include_distances=False)\n",
        "  items = [mapping[i] for i in ids]\n",
        "  return items"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hjerNpmZja0A"
      },
      "source": [
        "### Extract embedding from a given query"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "a0IIXzfBjZ19"
      },
      "outputs": [],
      "source": [
        "# Load the TF-Hub model\n",
        "print(\"Loading the TF-Hub model...\")\n",
        "%time embed_fn = hub.load(model_url)\n",
        "print(\"TF-Hub model is loaded.\")\n",
        "\n",
        "random_projection_matrix = None\n",
        "if os.path.exists('random_projection_matrix'):\n",
        "  print(\"Loading random projection matrix...\")\n",
        "  with open('random_projection_matrix', 'rb') as handle:\n",
        "    random_projection_matrix = pickle.load(handle)\n",
        "  print('Random projection matrix is loaded.')\n",
        "\n",
        "def extract_embeddings(query):\n",
        "  '''Generates the embedding for the query'''\n",
        "  query_embedding = embed_fn([query])[0].numpy()\n",
        "  if random_projection_matrix is not None:\n",
        "    query_embedding = query_embedding.dot(random_projection_matrix)\n",
        "  return query_embedding\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "kCoCNROujEIO"
      },
      "outputs": [],
      "source": [
        "extract_embeddings(\"Hello Machine Learning!\")[:10]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "koINo8Du--8C"
      },
      "source": [
        "### Enter a query to find the most similar items"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "wC0uLjvfk5nB"
      },
      "outputs": [],
      "source": [
        "#@title { run: \"auto\" }\n",
        "query = \"confronting global challenges\" #@param {type:\"string\"}\n",
        "\n",
        "print(\"Generating embedding for the query...\")\n",
        "%time query_embedding = extract_embeddings(query)\n",
        "\n",
        "print(\"\")\n",
        "print(\"Finding relevant items in the index...\")\n",
        "%time items = find_similar_items(query_embedding, 10)\n",
        "\n",
        "print(\"\")\n",
        "print(\"Results:\")\n",
        "print(\"=========\")\n",
        "for item in items:\n",
        "  print(item)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "TkRSqs77tDuX"
      },
      "source": [
        "## Want to learn more?\n",
        "\n",
        "You can learn more about TensorFlow at [tensorflow.org](https://www.tensorflow.org/) and see the TF-Hub API documentation at [tensorflow.org/hub](https://www.tensorflow.org/hub/). Find available TensorFlow Hub models at [tfhub.dev](https://tfhub.dev/), including more text embedding models and image feature vector models.\n",
        "\n",
        "Also check out the [Machine Learning Crash Course](https://developers.google.com/machine-learning/crash-course/), which is Google's fast-paced, practical introduction to machine learning."
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [
        "ACbjNjyO4f_8",
        "g6pXBVxsVUbm"
      ],
      "name": "tf2_semantic_approximate_nearest_neighbors.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
